{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 31、PyTorch GRU的原理及其手写复现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$\n",
    "\\begin{align*}\n",
    "r_t &= \\sigma(W_{ir}x_t + b_{ir} + W_{hr}h_{(t-1)} + b_{hr}) \\\\\n",
    "z_t &= \\sigma(W_{iz}x_t + b_{iz} + W_{hz}h_{(t-1)} + b_{hz}) \\\\\n",
    "n_t &= \\tanh(W_{in}x_t + b_{in} + r_t \\odot (W_{hn}h_{(t-1)} + b_{hn})) \\\\\n",
    "h_t &= (1 - z_t) \\odot n_t + z_t \\odot h_{(t-1)}\n",
    "\\end{align*}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "from torch import Tensor\n",
    "\n",
    "\n",
    "def gru_forward(\n",
    "    input: Tensor,\n",
    "    init: Tensor,\n",
    "    w_ih: Tensor,\n",
    "    w_hh: Tensor,\n",
    "    b_ih: Tensor,\n",
    "    b_hh: Tensor,\n",
    "):\n",
    "    prev_h = init\n",
    "    bs, T, i_size = input.shape\n",
    "    h_size = w_ih.shape[0] // 3\n",
    "\n",
    "    # reset, update, new\n",
    "    w_ir, w_iz, w_in = w_ih.chunk(3, 0)\n",
    "    w_hr, w_hz, w_hn = w_hh.chunk(3, 0)\n",
    "\n",
    "    b_ir, b_iz, b_in = b_ih.chunk(3, 0)\n",
    "    b_hr, b_hz, b_hn = b_hh.chunk(3, 0)\n",
    "\n",
    "    output = torch.zeros(bs, T, h_size)\n",
    "\n",
    "    for t in range(T):\n",
    "        x = input[:, t, :]\n",
    "        r_t = torch.sigmoid((x @ w_ir.T + b_ir) + (prev_h @ w_hr.T + b_hr))\n",
    "        z_t = torch.sigmoid((x @ w_iz.T + b_iz) + (prev_h @ w_hz.T + b_hz))\n",
    "        n_t = torch.tanh((x @ w_in.T + b_in) + r_t * (prev_h @ w_hn.T + b_hn))\n",
    "        h_t = (1 - z_t) * n_t + z_t * prev_h\n",
    "\n",
    "        prev_h = h_t\n",
    "        output[:, t, :] = h_t\n",
    "\n",
    "    return output, prev_h"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight_ih_l0\ttorch.Size([12, 3])\n",
      "weight_hh_l0\ttorch.Size([12, 4])\n",
      "bias_ih_l0\ttorch.Size([12])\n",
      "bias_hh_l0\ttorch.Size([12])\n",
      "---------\n"
     ]
    }
   ],
   "source": [
    "# Fix the seed so the random input and the GRU weights are reproducible.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "batch_size, seq_len = 2, 3\n",
    "input_size, hidden_size = 3, 4\n",
    "input = torch.randn((batch_size, seq_len, input_size))\n",
    "h0 = torch.zeros(batch_size, hidden_size)\n",
    "gru_layer = nn.GRU(input_size, hidden_size, batch_first=True)\n",
    "# Call the module itself (not .forward) so registered hooks run, and pass the\n",
    "# same initial state to both implementations; nn.GRU takes h_0 with a leading\n",
    "# num_layers axis, hence the unsqueeze.\n",
    "out1, h1 = gru_layer(input, h0.unsqueeze(0))\n",
    "for k, v in gru_layer.named_parameters():\n",
    "    print(f\"{k}\\t{v.shape}\")\n",
    "print(\"---------\")\n",
    "out2, h2 = gru_forward(\n",
    "    input,\n",
    "    h0,\n",
    "    gru_layer.weight_ih_l0,\n",
    "    gru_layer.weight_hh_l0,\n",
    "    gru_layer.bias_ih_l0,\n",
    "    gru_layer.bias_hh_l0\n",
    ")\n",
    "# A small absolute tolerance absorbs float32 rounding differences between the\n",
    "# built-in kernel and the step-by-step version.\n",
    "assert torch.allclose(out1, out2, atol=1e-6)\n",
    "# h1 is (1, batch, hidden); drop the layer axis so the shapes match exactly\n",
    "# instead of relying on broadcasting.\n",
    "assert torch.allclose(h1.squeeze(0), h2, atol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
